/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

/*
 * First midstate pass: one scalar and one curve point per work-item.
 *
 * For work-item `gid` the kernel
 *   1. seeds an Rng from seeds[gid] and draws a selector plus 32 random bytes,
 *   2. picks one of l_4 / l_5 / l_6 / l_7 and the matching [m_min, m_max] window,
 *   3. adjusts the random scalar m (low three bits, leading bits) against
 *      that window,
 *   4. writes m to midstate_bn[gid],
 *   5. runs scalar_mul_folds on m with the fold table and stores p.x, p.y and
 *      p.z through fe_store_global into midstate_ypx, midstate_ymx and zs.
 *
 * Note that the raw p.x / p.y are what land in midstate_ypx / midstate_ymx
 * here. Presumably compute_midstate_kernel_1 is launched afterwards on the
 * same buffers to turn them into the values their names suggest - TODO confirm
 * against the host code.
 */
KERNEL_SPEC(compute_midstate_kernel_0, BLOCK_SIZE)
(
    GLOBAL uint32_t* midstate_ypx,
    GLOBAL uint32_t* midstate_ymx,
    GLOBAL uint32_t* zs,
    GLOBAL Bn256* midstate_bn,
    GLOBAL const uint64_t* __restrict__ seeds,
    GLOBAL const Bn256 addend_sum,
    GLOBAL const struct AffineNielsPoint* __restrict__ basepoint_mul_fold_table_
) {
    const size_t gid = global_id();
    // NOTE(review): lid, gsize and lsize are not referenced by name in this
    // body. They are kept in case a helper macro expands to use them - confirm
    // before removing.
    const size_t lid = local_id();
    const size_t gsize = global_size();
    const size_t lsize = local_size();

    // Copy of the fold table declared LOCAL; it is only read after the
    // SYNCTHREADS() further down.
    LOCAL struct AffineNielsPoint basepoint_mul_fold_table[256];
    load_fold_table(basepoint_mul_fold_table, basepoint_mul_fold_table_);

    // Draw order matters: the selector first, then the 32 scalar bytes.
    struct Rng rng;
    rng_init(&rng, seeds[gid], 0);
    const uint32_t draw = rng_rand(&rng);
    const uint32_t bucket = draw % 9;

    Bn256 m;
    rng_fill_bytes(&rng, m, 32);

    Bn256 chosen_l, m_max, m_min;

    // The lower bound is addend_sum no matter which constant is picked.
    bn_copy(m_min, addend_sum);

    if (bucket >= 6) {
        // Buckets 6..8 (3 out of 9): l_7 with its dedicated upper bound.
        bn_copy(chosen_l, l_7);
        bn_copy(m_max, l_7_ub);
    } else {
        // Buckets 0..5 (2 out of 9 each): l_4, l_5 or l_6, all sharing the
        // upper bound l_minus_one - addend_sum.
        bn_sub(m_max, l_minus_one, addend_sum);
        if (bucket < 2) {
            bn_copy(chosen_l, l_4);
        } else if (bucket < 4) {
            bn_copy(chosen_l, l_5);
        } else {
            bn_copy(chosen_l, l_6);
        }
    }
    // TODO: also choose l_3, with very small probability. Intended window:
    //       m_max = l_minus_one - addend_sum, m_min = l_3_lb + addend_sum.

    const uint8_t max_lz = bn_clz(m_max);
    const uint8_t min_lz = bn_clz(m_min);

    // Replace the low three bits of m[0] with 8 - (chosen_l[0] % 8),
    // presumably so that m + chosen_l becomes a multiple of 8 (TODO confirm).
    // When chosen_l[0] % 8 == 0 the addend is 8 and carries into bit 3.
    uint8_t low_bits = 8 - chosen_l[0] % 8;
    m[0] = (m[0] & 0xf8) + low_bits;

    // Pull m under the upper bound: clear as many leading bits as m_max has
    // leading zeros, and one more if m is still greater than m_max.
    bn_clear_leading_bits(m, max_lz);
    if (bn_gt(m, m_max)) {
        bn_clear_leading_bits(m, max_lz + 1);
    }

    // Push m over the lower bound by setting a single bit derived from
    // min_lz - 1. Assuming byte 31 is the most significant byte and bn_clz
    // counts from its top bit, that is the bit just above m_min's highest
    // set bit - TODO confirm. The subtraction is deliberately done in int.
    if (bn_lt(m, m_min)) {
        const int above_top = min_lz - 1;
        uint8_t byte_idx = 31 - above_top / 8;
        uint8_t bit_idx = 7 - above_top % 8;
        m[byte_idx] |= (1 << bit_idx);
    }

    bn_copy(midstate_bn[gid], m);

    // Barrier between loading the LOCAL fold table and its first use below.
    SYNCTHREADS();

    struct EdwardsPoint point;
    scalar_mul_folds(&point, (const uint32_t*)m, basepoint_mul_fold_table);

    fe_store_global(midstate_ypx, point.x);
    fe_store_global(midstate_ymx, point.y);
    fe_store_global(zs, point.z);
}

/*
 * Second midstate pass: rescale the stored coordinates and derive the
 * xy / y+x / y-x triple.
 *
 * Reads one field element each from midstate_ypx, midstate_ymx and zs_inv,
 * multiplies the first two by the third, and then writes
 *   midstate_xy  <- x * y
 *   midstate_ypx <- y + x   (overwrites the input)
 *   midstate_ymx <- y - x   (overwrites the input)
 * where x and y are the rescaled values.
 *
 * NOTE(review): zs_inv presumably holds the inverses of the z values written
 * to `zs` by compute_midstate_kernel_0 - confirm against the host code.
 */
KERNEL_SPEC(compute_midstate_kernel_1, BLOCK_SIZE)
(
    GLOBAL uint32_t* midstate_ypx,
    GLOBAL uint32_t* midstate_ymx,
    GLOBAL uint32_t* midstate_xy,
    GLOBAL const uint32_t* __restrict__ zs_inv
) {
    FieldElement scale, px, py, scratch;

    // On entry midstate_ypx / midstate_ymx hold the two unscaled inputs.
    fe_load_global(scale, zs_inv);
    fe_load_global(px, midstate_ypx);
    fe_load_global(py, midstate_ymx);

    // Rescale both coordinates in place.
    fe_mul(px, px, scale);
    fe_mul(py, py, scale);

    // Product x * y.
    fe_mul(scratch, px, py);
    fe_store_global(midstate_xy, scratch);

    // Sum y + x replaces the first input buffer.
    fe_add(scratch, py, px);
    fe_store_global(midstate_ypx, scratch);

    // Difference y - x replaces the second input buffer.
    fe_sub(scratch, py, px);
    fe_store_global(midstate_ymx, scratch);
}